import os, csv
import pickle

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import sklearn
import sklearn.metrics
import torch
import torch.nn as nn
import wandb

from evaluate.robust.base import Adv_Analysis
from evaluate.robust.attacks import *
from tool.util import get_valid_unit

#sns.set(style="whitegrid")


class PerturbAnalysis(Adv_Analysis):
    """Run adversarial attacks on a model and report robustness statistics.

    Attack specs of the form ``"<name>-<N>"`` are treated as N-step attacks
    whose results are stored per step under ``"<name>-<N>_1" .. "<name>-<N>_N"``;
    any other spec is stored under its own name.  The reporting methods read
    ``self.exps`` / ``self.results_all`` (including an ``"orig"`` entry with
    clean logits), which are not set in this class — presumably provided by
    ``Adv_Analysis`` (TODO confirm).
    """

    def __init__(self, attacks, model, dloader, rootdir, args, overwrite=True):
        """Run (or load cached) attacks and collect their logits.

        Args:
            attacks: attack specs understood by ``att_helper``.
            model: model under attack, forwarded to the attack function.
            dloader: data loader; ``dataset.data`` and ``dataset.targets`` are
                pickled (two consecutive dumps) to ``<rootdir>/raw_data.pkl``.
            rootdir: output directory.
            args: run configuration forwarded to the base class and attacks.
            overwrite: when False, an attack for which ``self._exist`` reports a
                saved result is loaded instead of recomputed, and freshly
                computed results are saved.
        """
        super().__init__(model=model, dloader=dloader, rootdir=rootdir,
                         args=args, overwrite=overwrite)
        with open(os.path.join(rootdir, "raw_data.pkl"), "wb") as f:
            pickle.dump(dloader.dataset.data, f)
            pickle.dump(dloader.dataset.targets, f)

        for attack in attacks:
            if self._exist(attack) and not overwrite:
                # Reuse the cached result instead of re-running the attack.
                self.results_exp.update(self._load(attack))
                continue
            att_func, num_step = att_helper(attack)
            _, adv_logits, adv_ys = att_func(model=model, dl=dloader,
                                             args=args, num_step=num_step)
            if num_step is None:
                # Single-shot attack: only the first output is kept.
                result_dict = {attack: {"ys": adv_ys[0], "logits": adv_logits[0]}}
            else:
                # Multi-step attack: one entry per step, numbered from 1.
                result_dict = {
                    "{}_{}".format(attack, step): {"ys": ay, "logits": al}
                    for step, (al, ay) in enumerate(zip(adv_logits, adv_ys), start=1)
                }
            # NOTE(review): results are only persisted when overwrite is False, so
            # runs with overwrite=True never refresh the cache — confirm intended.
            if not overwrite:
                self._save(attack, result_dict)
            self.results_exp.update(result_dict)
        self.attacks = attacks

    def _file_name(self, postfix, rootdir=None):
        """Return ``<rootdir or self.rootdir>/robust <postfix>``."""
        return os.path.join(self.rootdir if rootdir is None else rootdir,
                            "robust {}".format(postfix))

    @staticmethod
    def _accuracy(logits, labels):
        """Return top-1 accuracy (in %) of stacked ``logits`` against ``labels``."""
        return (np.argmax(logits, -1) == labels).mean() * 100.0

    @staticmethod
    def _last_step_key(attack):
        """Return the results key holding the final step of ``attack``.

        ``"name-N"`` specs store their last step under ``"name-N_N"``; any other
        spec is stored under its own name.
        """
        params = attack.split("-")
        return "{}_{}".format(attack, params[1]) if len(params) == 2 else attack

    def _collect_results(self, attack):
        """Return a long-format DataFrame of accuracy per step for ``attack``.

        One row per (experiment, step) with columns ``exp``, ``acc (%)`` and
        ``steps``; step 0 is the clean ("orig") accuracy.  Adversarial
        accuracy is scored against the clean labels.
        """
        params = attack.split("-")
        if len(params) == 2:
            keys = ["{}_{}".format(attack, n) for n in range(1, int(params[1]) + 1)]
        else:
            keys = [attack]

        frames = []
        for exp in self.exps:
            exp_results = self.results_all[exp]
            ori_labels = np.stack(exp_results["orig"]["ys"])
            ori_logits = np.stack(exp_results["orig"]["logits"])
            accs = [self._accuracy(ori_logits, ori_labels)]
            accs += [self._accuracy(np.stack(exp_results[key]["logits"]), ori_labels)
                     for key in keys]
            frames.append(pd.DataFrame({
                "exp": [exp] * len(accs),
                "acc (%)": accs,
                "steps": list(range(len(accs))),
            }))
        if not frames:
            return pd.DataFrame()
        # A single concat avoids the quadratic cost of growing a frame in a loop.
        return pd.concat(frames, ignore_index=True)

    def plot(self, rootdir=None):
        """Plot accuracy vs. iteration for every attack; save the figure and a CSV.

        Args:
            rootdir: output directory; defaults to ``self.rootdir``.  The figure
                is saved via ``_file_name("robust", ...)`` and the per-attack
                accuracy columns to ``robust.csv`` in the same directory.
        """
        outdir = self.rootdir if rootdir is None else rootdir
        results = []
        # Use a dedicated figure so nothing leaks into later pyplot drawings.
        fig, ax = plt.subplots()
        for attack in self.attacks:
            df_results = self._collect_results(attack)
            results.append(df_results['acc (%)'])
            ax.plot(df_results.index, df_results['acc (%)'], label=attack.split("-")[0])
        ax.legend()
        ax.set_xlabel('number of iterations')
        ax.set_ylabel('robust acc.')
        fig.savefig(self._file_name("robust", outdir), dpi=self._dpi)
        plt.close(fig)
        results = pd.concat(results, axis=1)
        results.to_csv(os.path.join(outdir, "robust.csv"))

    def print(self):
        """Print the final-step accuracy of every (attack, experiment) pair.

        Returns:
            list of those accuracies (in %), attack-major, experiment-minor.
        """
        scores = []
        for attack in self.attacks:
            df_results = self._collect_results(attack)
            for exp in self.exps:
                df_exp = df_results.loc[df_results.exp == exp]
                # Positional lookup: df_exp keeps df_results' global index, so a
                # label lookup fails for every experiment after the first.
                final_acc = df_exp['acc (%)'].iloc[-1]
                print(f"attack : {attack} | acc : {final_acc}")
                scores.append(final_acc)
        return scores

    def to_dict(self):
        """Print and return the mean accuracy per step for each attack/experiment.

        Returns:
            ``{attack: {exp: {step: mean_acc}}}`` with accuracies in percent.
        """
        summary = {}
        for attack in self.attacks:
            df_results = self._collect_results(attack)
            summary[attack] = {}
            for exp in self.exps:
                df_exp = df_results.loc[df_results.exp == exp]
                step_means = {
                    int(step): float(np.mean(df_exp.loc[df_exp.steps == step]["acc (%)"].values))
                    for step in df_exp.steps.unique()
                }
                summary[attack][exp] = step_means
                print_str = "{} on {} : ".format(attack, exp) + "".join(
                    "{} - {:.3f} | ".format(step, acc) for step, acc in step_means.items())
                print(print_str)
        return summary

    def to_csv(self, rootdir=None):
        """Write the accuracy of every stored result to ``robustness.csv``.

        The header is ``" "`` followed by the result keys of the first
        experiment; every row is written in that same key order so columns stay
        aligned, and keys an experiment lacks are left blank.

        Args:
            rootdir: output directory; defaults to ``self.rootdir``.
        """
        outdir = self.rootdir if rootdir is None else rootdir
        with open(os.path.join(outdir, "robustness.csv"), "w", newline='') as f:
            wr = csv.writer(f)
            keys = None
            for exp in self.exps:
                exp_dict = self.results_all[exp]
                if keys is None:
                    keys = list(exp_dict.keys())
                    wr.writerow([" "] + keys)
                row_vals = [exp]
                for key in keys:
                    if key not in exp_dict:
                        row_vals.append("")
                        continue
                    logits = np.stack(exp_dict[key]["logits"])
                    labels = np.stack(exp_dict[key]["ys"])
                    row_vals.append(self._accuracy(logits, labels))
                wr.writerow(row_vals)

    @staticmethod
    def _ad_scores(logits):
        """Return ``(logsumexp, max-softmax)`` numpy scores for stacked ``logits``.

        Assumes ``np.stack(logits)`` is (samples, classes) — TODO confirm.
        """
        logits = torch.from_numpy(np.stack(logits))
        log_px = logits.logsumexp(1).detach().cpu().numpy()
        p_y_x = torch.softmax(logits, dim=1).max(1)[0].detach().cpu().numpy()
        return log_px, p_y_x

    def _save_hist(self, orig_values, attack_values, attack_label, title, xlabel, fname):
        """Overlay clean vs. attacked histograms, save to ``rootdir/fname``; return the path."""
        path = os.path.join(self.rootdir, fname)
        fig, ax = plt.subplots()
        ax.set_title(title, fontsize=12)
        ax.hist(orig_values, bins=100, alpha=0.5, label='original')
        ax.hist(attack_values, bins=100, alpha=0.5, label=attack_label)
        ax.legend(fontsize=12, loc='upper left')
        ax.tick_params(labelsize=12)
        ax.set_xlabel(xlabel, fontsize=12)
        ax.set_ylabel('count', fontsize=12)
        fig.savefig(path)
        plt.close(fig)
        return path

    def ad_detect(self, text_log):
        """Measure how separable clean and adversarial inputs are.

        For the first experiment only, clean ("orig") logits are compared with
        the final-step logits of each attack using two scores: ``logsumexp`` of
        the logits (reported as ``p_x``) and the max softmax probability
        (reported as ``p_y|x``).  Clean samples are the positive class.
        ROC-AUC and average precision (in %) are printed and stored; histograms
        of both scores are saved under ``self.rootdir`` and logged to wandb.

        Args:
            text_log: dict-like; its ``'adv. auc'`` and ``'adv. aupr'`` entries
                are replaced with fresh dicts keyed ``<attack>_p_x`` /
                ``<attack>_p_y|x``.
        """
        exp = self.exps[0]
        results = self.results_all[exp]

        text_log['adv. auc'] = dict()
        text_log['adv. aupr'] = dict()

        # Clean scores do not depend on the attack; compute them once.
        orig_log_px, orig_p_y_x = self._ad_scores(results["orig"]["logits"])
        for attack in self.attacks:
            # Each attack uses its own final step rather than the first attack's.
            attack_log_px, attack_p_y_x = self._ad_scores(
                results[self._last_step_key(attack)]['logits'])

            # Attacked samples are labelled 0 (fake), clean samples 1 (real).
            labels = np.concatenate([np.zeros_like(attack_log_px), np.ones_like(orig_log_px)])
            log_px_scores = np.concatenate([attack_log_px, orig_log_px])
            p_y_x_scores = np.concatenate([attack_p_y_x, orig_p_y_x])
            auc_p_x = sklearn.metrics.roc_auc_score(labels, log_px_scores)
            auc_p_y_x = sklearn.metrics.roc_auc_score(labels, p_y_x_scores)
            aupr_p_x = sklearn.metrics.average_precision_score(labels, log_px_scores)
            aupr_p_y_x = sklearn.metrics.average_precision_score(labels, p_y_x_scores)
            print(f"attack : {attack}")
            print(f"auc p_x : {auc_p_x}")
            print(f"auc p_y|x : {auc_p_y_x}")
            print(f"aupr p_x : {aupr_p_x}")
            print(f"aupr p_y|x : {aupr_p_y_x}")

            text_log['adv. auc'][attack + '_p_x'] = get_valid_unit(auc_p_x*100)
            text_log['adv. auc'][attack + '_p_y|x'] = get_valid_unit(auc_p_y_x*100)
            text_log['adv. aupr'][attack + '_p_x'] = get_valid_unit(aupr_p_x*100)
            text_log['adv. aupr'][attack + '_p_y|x'] = get_valid_unit(aupr_p_y_x*100)

            attack_name = attack.split('-')[0]
            path_p_y_x = self._save_hist(
                orig_p_y_x, attack_p_y_x, attack_name,
                f'ad_detect {attack_name} p_y|x', 'confidence',
                f'ad_detect_{attack}_p_y|x.png')
            wandb.log({
                f'ad_detect/{attack_name}_p_y|x': wandb.Image(path_p_y_x),
            })

            path_p_x = self._save_hist(
                orig_log_px, attack_log_px, attack_name,
                f'ad_detect {attack_name} p_x', 'logits',
                f'ad_detect_{attack}_p_x.png')
            wandb.log({
                f'ad_detect/{attack_name}_p_x': wandb.Image(path_p_x),
            })